其他
半精度(FP16)调试血泪总结
排除法查找节点
2.1 卷积层映射
def warp_conv(x, conv, factor: int = 32):
    """Evaluate ``conv(x)`` through a down-scaled intermediate activation.

    The convolution is run on ``x / factor`` and the result is scaled back,
    with the bias (which was not scaled) corrected afterwards:

        (W * (x / p) + b) * p - b * (p - 1) == W * x + b

    so ``warp_conv(x, conv)`` equals ``conv(x)`` up to rounding.

    Args:
        x: Input tensor accepted by ``conv``.
        conv: Convolution module. Its ``bias`` may be ``None``; its ``weight``
            is assumed to be laid out as ``(out, in, *kernel)`` so that
            ``weight.dim() - 2`` is the number of spatial axes.
        factor: The divisor ``p`` applied to the input.

    Returns:
        A tensor with the same shape as ``conv(x)``.
    """
    scaled = conv(x / factor)
    if conv.bias is None:
        # Without a bias the layer is purely linear: scaling back is enough.
        return factor * scaled
    # Shape (C, 1, ..., 1): broadcasts over batch and spatial axes, so the
    # bias never has to be tiled to the output's spatial size.
    spatial_axes = conv.weight.dim() - 2
    channel_bias = conv.bias.reshape(-1, *([1] * spatial_axes))
    return factor * scaled - (factor - 1) * channel_bias
2.2 归一化层映射
# Schematic formula: a normalisation layer evaluated with its running
# statistics, rewritten as one per-channel multiply-add. The names below are
# free variables standing for the layer's parameters and buffers.
# input:x, output:out
# w is the per-channel multiplier; (bias - running_mean * w) is the
# per-channel additive term.
w = weight / torch.sqrt(running_var + eps)
out = x * w + (bias - running_mean * w)
def warp_bn(x, bn, factor: int = 32):
    """Evaluate ``bn(x)`` through a down-scaled intermediate activation.

    With the layer written as ``x * scale + shift`` (``scale`` and ``shift``
    derived from the running statistics), the layer is run on ``x / factor``
    and the additive term, which was not scaled, is corrected afterwards:

        bn(x / p) * p - (p - 1) * shift == x * scale + shift

    The identity relies on ``bn`` normalising with ``running_mean`` /
    ``running_var``, i.e. the layer must be in eval mode and must track
    running statistics.

    Args:
        x: Batched input of shape ``(N, C, ...)``.
        bn: Normalisation module exposing ``running_mean``, ``running_var``,
            ``eps`` and (possibly ``None``) ``weight`` / ``bias``.
        factor: The divisor ``p`` applied to the input.

    Returns:
        A tensor with the same shape as ``bn(x)``.
    """
    import torch

    scale = 1 / torch.sqrt(bn.running_var + bn.eps)
    if bn.weight is not None:
        scale = bn.weight * scale
    shift = -bn.running_mean * scale
    if bn.bias is not None:
        shift = bn.bias + shift
    # Shape (C, 1, ..., 1): broadcasts over batch and trailing axes, so the
    # term never has to be tiled to the input's spatial size.
    shift = shift.reshape(-1, *([1] * (x.dim() - 2)))
    return bn(x / factor) * factor - (factor - 1) * shift
2.3 验证
# Both values should be ~0 (rounding error only). The subtraction is
# parenthesised so the reduction applies to the residual rather than to the
# warp_* result alone, and the largest absolute residual is reported because
# a signed sum can cancel to zero even when individual elements differ.
print((conv(x) - warp_conv(x, conv)).abs().max())
print((bn(x) - warp_bn(x, bn)).abs().max())
整理思路
3.1 数值溢出的可能形式
(1) 某个算子内部计算过程数值溢出,但输入输出均可以用 fp16 表示;(2) 连续多个算子之间出现数值溢出;(3) 整个网络计算过程都有数值溢出。
3.2 多算子数值溢出
寻找fp16失效的算子
每次提前返回结果,二分地导出 ONNX 再导出 TensorRT 模型,未被导出的部分继续以 PyTorch 代码衔接到 TensorRT 的计算结果后。
直接运行 PyTorch 模型,设置断点,查看哪些计算过程的数值异常地大。
4.1 断点调试
4.2 优化
from pyclbr import Function
from typing import Callable, Iterator, Sequence

import torch
def _fp16_peaks(value):
    """Yield max(|t|) for every non-empty floating-point tensor nested in *value*."""
    if isinstance(value, torch.Tensor):
        # Bool/integer tensors are skipped, and an empty tensor has no maximum.
        if value.is_floating_point() and value.numel() > 0:
            yield value.detach().abs().max().item()
        return
    if isinstance(value, (str, bytes)):
        # A str is a Sequence of str: descending into it would never end.
        return
    if isinstance(value, dict):
        value = value.values()
    elif not isinstance(value, Sequence):
        # None, numbers and other opaque objects carry nothing to inspect.
        return
    for item in value:
        yield from _fp16_peaks(item)


def fp16_check(module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor) -> None:
    """Forward hook that reports where values enter or leave the fp16 range.

    ``input`` and ``output`` may be tensors or arbitrarily nested dicts and
    sequences of tensors; non-tensor leaves and non-floating-point tensors are
    ignored. Every (input tensor, output tensor) pair is compared:

    * input within range, output beyond it -> prints ``from:  <name>``
    * input beyond range, output within it -> prints ``to:  <name>``

    "Within range" means ``max(|t|) <= 65504``, the largest finite fp16 value.
    ``<name>`` is read from ``module.finspect_name``.

    Args:
        module: The module whose forward pass just ran.
        input: Positional inputs of the forward call.
        output: Result of the forward call.
    """
    fp16_max = torch.finfo(torch.float16).max
    # Reduce every tensor once; the pairwise loop below only compares floats.
    output_peaks = list(_fp16_peaks(output))
    for in_peak in _fp16_peaks(input):
        for out_peak in output_peaks:
            if in_peak <= fp16_max < out_peak:
                print('from: ', module.finspect_name)
            if out_peak <= fp16_max < in_peak:
                print('to: ', module.finspect_name)
from contextlib import contextmanager
class FInspect:
    """Attach one forward hook to every module of a model, labelled by path.

    Each hooked module gets a ``finspect_name`` attribute such as
    ``'model->backbone->layer1'`` so the hook can report where it fired.
    """

    # Path components prepended to every label; the root module is labelled
    # with exactly this prefix. Read-only: it is never mutated while hooking.
    module_names = ['model']
    # Handles of hooks that are currently registered through this class.
    handlers = []

    # Deliberately left undecorated: it is invoked with the class passed
    # explicitly as the first argument.
    def hook_all_impl(cls, module: torch.nn.Module, hook_func: Callable[..., None]) -> None:
        """Label and hook ``module`` and all of its descendants.

        Children are processed before their parent. The returned hook handles
        are appended to ``cls.handlers``.

        Args:
            cls: The class holding ``module_names`` and ``handlers``.
            module: Root of the module tree to instrument.
            hook_func: Forward hook registered on every module.
        """

        def visit(node, path):
            # The path travels as an argument, not through shared class
            # state, so the prefix is intact for the next call.
            for child_name, child in node.named_children():
                visit(child, path + (child_name,))
            setattr(node, 'finspect_name', '->'.join(path))
            cls.handlers.append(node.register_forward_hook(hook=hook_func))

        visit(module, tuple(cls.module_names))

    @classmethod
    @contextmanager
    def hook_all(cls, module: torch.nn.Module, hook_func: Callable[..., None]) -> Iterator[None]:
        """Context manager: hook every module on entry, unhook on exit.

        The hooks registered by this call are removed even if the ``with``
        body (or the registration itself) raises.

        Args:
            module: Root of the module tree to instrument.
            hook_func: Forward hook registered on every module.
        """
        # Only handles added from here on belong to this context.
        first_new = len(cls.handlers)
        try:
            cls.hook_all_impl(cls, module, hook_func)
            yield
        finally:
            for handle in cls.handlers[first_new:]:
                handle.remove()
            del cls.handlers[first_new:]
# Hook fp16_check onto every module of patched_model for the duration of a
# single forward pass; the hooks are taken off again when the block exits.
inspection = FInspect.hook_all(patched_model, fp16_check)
with inspection:
    patched_model(inputs)
尝试映射
5.1 relu一生之敌
5.2 公式
5.3 实施
mmocr.models.textdet.necks.FPEM_FFM.forward
mmdet.models.backbones.resnet.BasicBlock.forward
import torch
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils.constants import Backend
# Divisor applied to the activation that would overflow fp16 in
# basic_block__forward__trt, and the multiplier used to undo that scaling in
# fpem_ffm__forward__trt.
FACTOR = 32
# Set to True by basic_block__forward__trt when it returns a down-scaled
# output; fpem_ffm__forward__trt reads it to decide whether to rescale c5.
ENABLE = False
# BasicBlocks whose conv1.in_channels is below this value keep their original
# forward and are not rewritten.
CHANNEL_THRESH = 400
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmocr.models.textdet.necks.FPEM_FFM.forward',
    backend=Backend.TENSORRT.value)
def fpem_ffm__forward__trt(ctx, self, x, *args, **kwargs):
    """TensorRT rewrite of ``FPEM_FFM.forward`` that can undo a scaled ``c5``.

    When the module-level ``ENABLE`` flag is set, ``c5`` is expected to arrive
    divided by ``FACTOR`` (as produced by the BasicBlock rewriter in this
    file). The conv + norm pair at the head of ``reduce_conv_c5`` is then run
    on the scaled input and corrected with

        FACTOR * bn(conv(c5)) - (FACTOR - 1) * (bn_w * conv_b + bn_b)

    which equals ``bn(conv(FACTOR * c5))``, before the last layer of
    ``reduce_conv_c5`` is applied. Everything else follows the usual
    reduce -> FPEM -> FFM flow.

    Args:
        ctx: Rewriter context (unused here).
        self: The ``FPEM_FFM`` module instance.
        x: Four feature maps ``(c2, c3, c4, c5)``.

    Returns:
        Tuple ``(c2, c3, c4, c5)`` of fused maps, all at ``c2``'s spatial size.
    """
    c2, c3, c4, c5 = x
    # reduce channel
    c2 = self.reduce_conv_c2(c2)
    c3 = self.reduce_conv_c3(c3)
    c4 = self.reduce_conv_c4(c4)
    if ENABLE:
        conv, bn = self.reduce_conv_c5[0], self.reduce_conv_c5[1]
        bn_w = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        bn_b = bn.bias - bn.running_mean * bn_w
        # Constant (input-independent) part of bn(conv(.)). A conv without a
        # bias contributes nothing to it.
        offset = bn_b if conv.bias is None else bn_w * conv.bias + bn_b
        # (1, C, 1, 1) broadcasts against the conv output, so the correction
        # does not depend on the spatial size of the conv *input*.
        offset = offset.reshape(1, -1, 1, 1)
        c5 = FACTOR * self.reduce_conv_c5[:-1](c5) - (FACTOR - 1) * offset
        c5 = self.reduce_conv_c5[-1](c5)
    else:
        c5 = self.reduce_conv_c5(c5)

    # FPEM: accumulate the outputs of every enhancement module.
    for i, fpem in enumerate(self.fpems):
        c2, c3, c4, c5 = fpem(c2, c3, c4, c5)
        if i == 0:
            c2_ffm = c2
            c3_ffm = c3
            c4_ffm = c4
            c5_ffm = c5
        else:
            c2_ffm += c2
            c3_ffm += c3
            c4_ffm += c4
            c5_ffm += c5

    # FFM: bring every level to c2's resolution.
    target_size = c2_ffm.size()[-2:]

    def upsample(feature):
        return F.interpolate(
            feature,
            target_size,
            mode='bilinear',
            align_corners=self.align_corners)

    c5 = upsample(c5_ffm)
    c4 = upsample(c4_ffm)
    c3 = upsample(c3_ffm)
    return (c2_ffm, c3, c4, c5)
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.models.backbones.resnet.BasicBlock.forward',
    backend=Backend.TENSORRT.value)
def basic_block__forward__trt(ctx, self, x):
    """TensorRT rewrite of ``BasicBlock.forward`` that avoids fp16 overflow.

    Blocks with fewer than ``CHANNEL_THRESH`` input channels run the original
    forward. For the others, if the output of ``norm2`` stays below 65504 the
    usual residual computation is returned. Otherwise the module-level
    ``ENABLE`` flag is set and the block returns

        relu((norm2(out) + identity) / FACTOR)

    with ``norm2`` expanded into its multiply-add form so that the unscaled
    value is never materialised. The consumer is expected to undo the
    ``1 / FACTOR`` scaling.

    Args:
        ctx: Rewriter context; ``ctx.origin_func`` is the original forward.
        self: The ``BasicBlock`` instance.
        x: Input feature map.

    Returns:
        The block output, divided by ``FACTOR`` in the overflow case.
    """
    global ENABLE

    if self.conv1.in_channels < CHANNEL_THRESH:
        return ctx.origin_func(self, x)

    identity = x
    # Keep the shortcut projection of the forward being replaced, if any.
    # NOTE(review): assumes the block exposes it as `downsample` - confirm.
    downsample = getattr(self, 'downsample', None)
    if downsample is not None:
        identity = downsample(x)

    out = self.conv1(x)
    out = self.norm1(out)
    out = self.relu(out)
    out = self.conv2(out)

    # Evaluate norm2 once and reuse it for both the range check and the result.
    normed = self.norm2(out)
    if torch.abs(normed).max() < 65504:
        normed += identity
        return self.relu(normed)

    ENABLE = True
    # the output of the last bn layer exceeds the range of fp16
    w1 = self.norm2.weight / torch.sqrt(self.norm2.running_var +
                                        self.norm2.eps)
    bias = self.norm2.bias - self.norm2.running_mean * w1
    # (1, C, 1, 1) broadcasts against the feature maps; no tiling needed.
    w1 = w1.reshape(1, -1, 1, 1)
    bias = bias.reshape(1, -1, 1, 1) + identity
    return self.relu(w1 * (out / FACTOR) + bias / FACTOR)
5.4 总结
本文总结了两点:(1) 一个快速查找数值溢出算子的方法;(2) 一个替换多个算子、从原始模型解决 FP16 数值溢出的方法。
更多阅读
#投 稿 通 道#
让你的文字被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析、科研心得或竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。
📝 稿件基本要求:
• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注
• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题
• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力的稿酬,具体依据文章阅读量和文章质量阶梯制结算
📬 投稿通道:
• 投稿邮箱:hr@paperweekly.site
• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者
• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿
△长按添加PaperWeekly小编
🔍
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧